{ "cells": [ { "cell_type": "markdown", "id": "32szzPY4RyWO", "metadata": { "id": "32szzPY4RyWO" }, "source": [ "### **4. R learner**\n", "The idea of the classical R-learner dates back to Robinson (1988) [3] and was formalized by Nie and Wager (2020) [2]. The R-learner starts from the partially linear model, in which we assume that\n", "\\begin{equation}\n", " \\begin{aligned}\n", " R&=A\\tau(S)+g_0(S)+U,\\\\\n", " A&=m_0(S)+V,\n", " \\end{aligned}\n", "\\end{equation}\n", "where $U$ and $V$ satisfy $\\mathbb{E}[U|A,S]=0$ and $\\mathbb{E}[V|S]=0$.\n", "\n", "Taking the conditional expectation of the first equation given $S$, and using $\\mathbb{E}[A|S]=m_0(S)$ and $\\mathbb{E}[U|S]=0$, gives $\\mathbb{E}[R|S]=\\tau(S)m_0(S)+g_0(S)$. Subtracting this from the first equation yields\n", "\\begin{equation}\n", "\tR-\\mathbb{E}[R|S]=\\tau(S)\\cdot(A-\\mathbb{E}[A|S])+\\epsilon,\n", "\\end{equation}\n", "where $\\epsilon=U$. Write $\\eta_0(S)=\\mathbb{E}[R|S]$ for the conditional mean outcome, and recall that $m_0(S)=\\mathbb{E}[A|S]$. A natural way to estimate $\\tau(S)$, which is also the core idea of the R-learner, is the following two-step procedure:\n", "\n", "**Step 1**: Regress $R$ on $S$ to obtain $\\hat{\\eta}(S)=\\hat{\\mathbb{E}}[R|S]$, and regress $A$ on $S$ to obtain $\\hat{m}(S)=\\hat{\\mathbb{E}}[A|S]$.\n", "\n", "**Step 2**: Regress the outcome residual $R-\\hat{\\eta}(S)$ on the propensity score residual $A-\\hat{m}(S)$.\n", "\n", "That is,\n", "\\begin{equation}\n", "\t\\hat{\\tau}(S)=\\arg\\min_{\\tau}\\left\\{\\mathbb{E}_n\\left[\\left(\\{R_i-\\hat{\\eta}(S_i)\\}-\\{A_i-\\hat{m}(S_i)\\}\\cdot\\tau(S_i)\\right)^2\\right]\\right\\}\t\n", "\\end{equation}\n", "\n", "The simplest choice is to restrict $\\tau(S)$ to the linear function class. 
In this case, $\\tau(S)=S\\beta$, and estimating $\\tau$ reduces to estimating $\\beta$ through the following least-squares regression:\n", "\\begin{equation}\n", "\t\\hat{\\beta}=\\arg\\min_{\\beta}\\left\\{\\mathbb{E}_n\\left[\\left(\\{R_i-\\hat{\\eta}(S_i)\\}-\\{A_i-\\hat{m}(S_i)\\} S_i\\cdot \\beta\\right)^2\\right]\\right\\}.\n", "\\end{equation}\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# import related packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL\n", "from causaldm.learners.CEL.Single_Stage.Rlearner import Rlearner\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { "text/html": [ "
\n", "
|  | user_id | movie_id | rating | age | Drama | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48.0 | 1193.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 1 | 48.0 | 919.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 2 | 48.0 | 527.0 | 5.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 3 | 48.0 | 1721.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 4 | 48.0 | 150.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

65642 rows × 12 columns
\n", "